{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8456b5b6-185b-440b-ab98-1822aac2fe4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "with open('documents-with-ids.json', 'rt') as f_in:\n",
    "    documents = json.load(f_in)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5833e987-21f5-4788-8609-d32410bc7b10",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import SentenceTransformer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3e482cd0-0202-4a89-854d-90b50c75e520",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = 'multi-qa-MiniLM-L6-cos-v1'\n",
    "model = SentenceTransformer(model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "341b331f-e3aa-41b7-91a2-b42cd3a57c2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "v = model.encode('I just discovered the course. Can I still join?')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "5783a592-5d76-44ee-89a7-05615ef966bf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "384"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "620ed5a1-cc06-40a3-8627-891aed525cba",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ObjectApiResponse({'acknowledged': True, 'shards_acknowledged': True, 'index': 'course-questions'})"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from elasticsearch import Elasticsearch\n",
    "\n",
    "es_client = Elasticsearch('http://localhost:9200') \n",
    "\n",
    "index_settings = {\n",
    "    \"settings\": {\n",
    "        \"number_of_shards\": 1,\n",
    "        \"number_of_replicas\": 0\n",
    "    },\n",
    "    \"mappings\": {\n",
    "        \"properties\": {\n",
    "            \"text\": {\"type\": \"text\"},\n",
    "            \"section\": {\"type\": \"text\"},\n",
    "            \"question\": {\"type\": \"text\"},\n",
    "            \"course\": {\"type\": \"keyword\"},\n",
    "            \"id\": {\"type\": \"keyword\"},\n",
    "            \"question_vector\": {\n",
    "                \"type\": \"dense_vector\",\n",
    "                \"dims\": 384,\n",
    "                \"index\": True,\n",
    "                \"similarity\": \"cosine\"\n",
    "            },\n",
    "            \"text_vector\": {\n",
    "                \"type\": \"dense_vector\",\n",
    "                \"dims\": 384,\n",
    "                \"index\": True,\n",
    "                \"similarity\": \"cosine\"\n",
    "            },\n",
    "            \"question_text_vector\": {\n",
    "                \"type\": \"dense_vector\",\n",
    "                \"dims\": 384,\n",
    "                \"index\": True,\n",
    "                \"similarity\": \"cosine\"\n",
    "            },\n",
    "        }\n",
    "    }\n",
    "}\n",
    "\n",
    "index_name = \"course-questions\"\n",
    "\n",
    "es_client.indices.delete(index=index_name, ignore_unavailable=True)\n",
    "es_client.indices.create(index=index_name, body=index_settings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "91721a90-96d6-4ef9-8baf-a3d42c4b03e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "ad8a4ba1-35d8-4650-9267-86afabef7386",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "eb5b7e669e31453d80aef3a5c2fdbe2b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/948 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "for doc in tqdm(documents):\n",
    "    question = doc['question']\n",
    "    text = doc['text']\n",
    "    qt = question + ' ' + text\n",
    "\n",
    "    doc['question_vector'] = model.encode(question)\n",
    "    doc['text_vector'] = model.encode(text)\n",
    "    doc['question_text_vector'] = model.encode(qt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d5b27abe-a0c2-426c-9780-3b1c0534dbe6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fb4da719c1374c1ebd437dd5b3e53e1e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/948 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "for doc in tqdm(documents):\n",
    "    es_client.index(index=index_name, document=doc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "82a78984-1440-4195-9513-106109013fce",
   "metadata": {},
   "outputs": [],
   "source": [
    "query = 'I just discovered the course. Can I still join it?'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "8a5ac3e8-8d75-48f6-8d3c-0b25d227d1f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "v_q = model.encode(query)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "ecba8e3b-7f61-402c-8106-62bde9eb1746",
   "metadata": {},
   "outputs": [],
   "source": [
    "def elastic_search_knn(field, vector, course):\n",
    "    knn = {\n",
    "        \"field\": field,\n",
    "        \"query_vector\": vector,\n",
    "        \"k\": 5,\n",
    "        \"num_candidates\": 10000,\n",
    "        \"filter\": {\n",
    "            \"term\": {\n",
    "                \"course\": course\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "\n",
    "    search_query = {\n",
    "        \"knn\": knn,\n",
    "        \"_source\": [\"text\", \"section\", \"question\", \"course\", \"id\"]\n",
    "    }\n",
    "\n",
    "    es_results = es_client.search(\n",
    "        index=index_name,\n",
    "        body=search_query\n",
    "    )\n",
    "    \n",
    "    result_docs = []\n",
    "    \n",
    "    for hit in es_results['hits']['hits']:\n",
    "        result_docs.append(hit['_source'])\n",
    "\n",
    "    return result_docs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "cc09b90a-a88b-4678-a613-eb68e16136d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def question_vector_knn(q):\n",
    "    question = q['question']\n",
    "    course = q['course']\n",
    "\n",
    "    v_q = model.encode(question)\n",
    "\n",
    "    return elastic_search_knn('question_vector', v_q, course)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "163d5e97-e1b7-45e1-ba1e-61fca4fc37a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "e3da963b-cd17-471b-9cff-7f8c4276b915",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_ground_truth = pd.read_csv('ground-truth-data.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "d53d3beb-d699-46c8-bc85-953afc4ef48d",
   "metadata": {},
   "outputs": [],
   "source": [
    "ground_truth = df_ground_truth.to_dict(orient='records')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "0deced29-bae5-4b6c-b2b8-53ce46b4f8ab",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'question': 'When does the course begin?',\n",
       " 'course': 'data-engineering-zoomcamp',\n",
       " 'document': 'c02e79ef'}"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ground_truth[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "8959c9ff-5bbe-4729-8fa3-cdc51ed10f5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def hit_rate(relevance_total):\n",
    "    cnt = 0\n",
    "\n",
    "    for line in relevance_total:\n",
    "        if True in line:\n",
    "            cnt = cnt + 1\n",
    "\n",
    "    return cnt / len(relevance_total)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e2970e2-4481-4378-8e29-4fe30369cd7f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "842255b5-18f2-4102-9689-a5835e0a621c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mrr(relevance_total):\n",
    "    total_score = 0.0\n",
    "\n",
    "    for line in relevance_total:\n",
    "        for rank in range(len(line)):\n",
    "            if line[rank] == True:\n",
    "                total_score = total_score + 1 / (rank + 1)\n",
    "\n",
    "    return total_score / len(relevance_total)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "f11baaff-43d9-4b8c-a896-561b86e85743",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(ground_truth, search_function):\n",
    "    relevance_total = []\n",
    "\n",
    "    for q in tqdm(ground_truth):\n",
    "        doc_id = q['document']\n",
    "        results = search_function(q)\n",
    "        relevance = [d['id'] == doc_id for d in results]\n",
    "        relevance_total.append(relevance)\n",
    "\n",
    "    return {\n",
    "        'hit_rate': hit_rate(relevance_total),\n",
    "        'mrr': mrr(relevance_total),\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "f1d3530e-1406-49dd-bba9-914f6a39d7f2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2f89bc2d454c4df48c39e8742bf362e3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4627 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'hit_rate': 0.773071104387292, 'mrr': 0.6666810748505158}"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "evaluate(ground_truth, question_vector_knn)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3aa5eb82-27df-4577-a2e2-23971f2fa964",
   "metadata": {},
   "source": [
    "ES text only: 0.7395720769397017, 0.6032418413658963"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "e87f0987-02e9-4e2c-a091-14a75c9ce58d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def text_vector_knn(q):\n",
    "    question = q['question']\n",
    "    course = q['course']\n",
    "\n",
    "    v_q = model.encode(question)\n",
    "\n",
    "    return elastic_search_knn('text_vector', v_q, course)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "d676bd5e-4abc-4799-bb37-c1a86b3b9872",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b14f76431e374bf2beeba36608578533",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4627 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'hit_rate': 0.8286146531229739, 'mrr': 0.7062315395144454}"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "evaluate(ground_truth, text_vector_knn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "13a47c3e-036a-4212-912c-a61de0daf6cd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ce2651726e11402c9c43b8adc72904f5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4627 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'hit_rate': 0.9172249837907932, 'mrr': 0.824306606152295}"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def question_text_vector_knn(q):\n",
    "    question = q['question']\n",
    "    course = q['course']\n",
    "\n",
    "    v_q = model.encode(question)\n",
    "\n",
    "    return elastic_search_knn('question_text_vector', v_q, course)\n",
    "\n",
    "evaluate(ground_truth, question_text_vector_knn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "20f4d5f5-617d-4e89-8cc1-4acaed270be4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def elastic_search_knn_combined(vector, course):\n",
    "    search_query = {\n",
    "        \"size\": 5,\n",
    "        \"query\": {\n",
    "            \"bool\": {\n",
    "                \"must\": [\n",
    "                    {\n",
    "                        \"script_score\": {\n",
    "                            \"query\": {\n",
    "                                \"term\": {\n",
    "                                    \"course\": course\n",
    "                                }\n",
    "                            },\n",
    "                            \"script\": {\n",
    "                                \"source\": \"\"\"\n",
    "                                    cosineSimilarity(params.query_vector, 'question_vector') + \n",
    "                                    cosineSimilarity(params.query_vector, 'text_vector') + \n",
    "                                    cosineSimilarity(params.query_vector, 'question_text_vector') + \n",
    "                                    1\n",
    "                                \"\"\",\n",
    "                                \"params\": {\n",
    "                                    \"query_vector\": vector\n",
    "                                }\n",
    "                            }\n",
    "                        }\n",
    "                    }\n",
    "                ],\n",
    "                \"filter\": {\n",
    "                    \"term\": {\n",
    "                        \"course\": course\n",
    "                    }\n",
    "                }\n",
    "            }\n",
    "        },\n",
    "        \"_source\": [\"text\", \"section\", \"question\", \"course\", \"id\"]\n",
    "    }\n",
    "\n",
    "    es_results = es_client.search(\n",
    "        index=index_name,\n",
    "        body=search_query\n",
    "    )\n",
    "    \n",
    "    result_docs = []\n",
    "    \n",
    "    for hit in es_results['hits']['hits']:\n",
    "        result_docs.append(hit['_source'])\n",
    "\n",
    "    return result_docs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "7ba72a59-7d0e-4d61-90c4-008e6341f7dd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "23d4bec1617342b2a61ae3ff37e6a2f1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4627 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'hit_rate': 0.9023125135076724, 'mrr': 0.804480945176861}"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def vector_combined_knn(q):\n",
    "    question = q['question']\n",
    "    course = q['course']\n",
    "\n",
    "    v_q = model.encode(question)\n",
    "\n",
    "    return elastic_search_knn_combined(v_q, course)\n",
    "\n",
    "evaluate(ground_truth, vector_combined_knn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "806798ff-431e-46a0-8ecd-edd9b688f8e3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
